#  -*- coding: utf-8 -*-
#  Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#  http://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.


"""
msmodelslim.utils.dag_utils.model_structure_process 模块的单元测试
"""

import unittest
from unittest.mock import Mock

from ascend_utils.core.dag.dag_node import DagNode
from msmodelslim.utils.dag_utils.model_infos import ModuleType
from msmodelslim.utils.dag_utils.model_structure_process import StructureProcess


class MockDagNode:
    def __init__(self, name, op_type, in_features=0, out_features=0, inputs=None, input_nodes=None, output_nodes=None):
        self.name_in_network = name
        self.op_type = op_type
        self.node = Mock()
        self.node.in_features = in_features
        self.node.out_features = out_features
        self.inputs = inputs or []
        self.input_nodes = input_nodes or []
        self.output_nodes = output_nodes or []


class TestStructureProcess(unittest.TestCase):
    """测试StructureProcess类"""

    def test_is_ffn_matmul_with_ffn_matmul_num_2_true_case(self):
        """测试is_ffn_matmul函数，ffn_matmul_num=2，返回True的情况"""
        matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
        matmul2 = MockDagNode("matmul2", "Linear", in_features=256, out_features=128)

        matmul_list = [matmul1, matmul2]
        result = StructureProcess.is_ffn_matmul(matmul_list, 2)
        self.assertTrue(result)

    def test_is_ffn_matmul_with_ffn_matmul_num_2_false_case(self):
        """测试is_ffn_matmul函数，ffn_matmul_num=2，返回False的情况"""
        matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
        matmul2 = MockDagNode("matmul2", "Linear", in_features=128, out_features=64)

        matmul_list = [matmul1, matmul2]
        result = StructureProcess.is_ffn_matmul(matmul_list, 2)
        self.assertFalse(result)

    def test_is_ffn_matmul_with_wrong_length(self):
        """测试is_ffn_matmul函数，长度不匹配的情况"""
        matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
        matmul_list = [matmul1]
        result = StructureProcess.is_ffn_matmul(matmul_list, 2)
        self.assertFalse(result)

    def test_is_ffn_matmul_with_ffn_matmul_num_3_true_case1(self):
        """测试is_ffn_matmul函数，ffn_matmul_num=3，返回True的情况1"""
        matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
        matmul2 = MockDagNode("matmul2", "Linear", in_features=256, out_features=128)
        matmul3 = MockDagNode("matmul3", "Linear", in_features=64, out_features=32)

        matmul_list = [matmul1, matmul2, matmul3]
        result = StructureProcess.is_ffn_matmul(matmul_list, 3)
        self.assertTrue(result)

    def test_is_ffn_matmul_with_ffn_matmul_num_3_true_case2(self):
        """测试is_ffn_matmul函数，ffn_matmul_num=3，返回True的情况2"""
        matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
        matmul2 = MockDagNode("matmul2", "Linear", in_features=256, out_features=128)
        matmul3 = MockDagNode("matmul3", "Linear", in_features=256, out_features=128)

        matmul_list = [matmul1, matmul2, matmul3]
        result = StructureProcess.is_ffn_matmul(matmul_list, 3)
        self.assertTrue(result)

    def test_is_ffn_matmul_with_ffn_matmul_num_3_false_case(self):
        """测试is_ffn_matmul函数，ffn_matmul_num=3，返回False的情况"""
        matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
        matmul2 = MockDagNode("matmul2", "Linear", in_features=128, out_features=64)
        matmul3 = MockDagNode("matmul3", "Linear", in_features=32, out_features=16)

        matmul_list = [matmul1, matmul2, matmul3]
        result = StructureProcess.is_ffn_matmul(matmul_list, 3)
        self.assertFalse(result)

    def test_is_ffn_matmul_with_unsupported_ffn_matmul_num(self):
        """测试is_ffn_matmul函数，不支持的ffn_matmul_num"""
        matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
        matmul2 = MockDagNode("matmul2", "Linear", in_features=256, out_features=128)
        matmul3 = MockDagNode("matmul3", "Linear", in_features=128, out_features=256)
        matmul4 = MockDagNode("matmul4", "Linear", in_features=256, out_features=128)

        matmul_list = [matmul1, matmul2, matmul3, matmul4]

        with self.assertRaises(Exception) as context:
            StructureProcess.is_ffn_matmul(matmul_list, 4)
        self.assertIn("unsupported ffn_matmul_num: 4", str(context.exception))

    def test_mhsa_matmul_process_with_proj_mat_out_features_times_3_equal_qkv_mat_out_features(self):
        """测试mhsa_matmul_process函数，proj_mat.out_features * 3 == qkv_mat.out_features的情况"""
        proj_mat = MockDagNode("proj", "Linear", in_features=256, out_features=128)
        qkv_mat = MockDagNode("qkv", "Linear", in_features=128, out_features=384)

        matmul_list = [proj_mat, qkv_mat]
        qkv_list = []
        proj_list = []

        StructureProcess.mhsa_matmul_process(matmul_list, qkv_list, proj_list)

        self.assertEqual(qkv_list, [["qkv"]])
        self.assertEqual(proj_list, ["proj"])

    def test_mhsa_matmul_process_with_proj_mat_out_features_equal_qkv_mat_out_features(self):
        """测试mhsa_matmul_process函数，proj_mat.out_features == qkv_mat.out_features的情况"""
        input_node = MockDagNode("input", "Add")
        qkv1 = MockDagNode("q", "Linear", in_features=128, out_features=256)
        qkv2 = MockDagNode("k", "Linear", in_features=128, out_features=256)
        qkv3 = MockDagNode("v", "Linear", in_features=128, out_features=256)

        input_node.output_nodes = [qkv1, qkv2, qkv3, MockDagNode("other", "ReLU")]

        qkv_mat = MockDagNode("qkv", "Linear", in_features=128, out_features=256)
        qkv_mat.input_nodes = [input_node]
        qkv_mat.inputs = [Mock()]

        proj_mat = MockDagNode("proj", "Linear", in_features=256, out_features=256)

        matmul_list = [proj_mat, qkv_mat]
        qkv_list = []
        proj_list = []

        StructureProcess.mhsa_matmul_process(matmul_list, qkv_list, proj_list)

        self.assertEqual(qkv_list, [["q", "k", "v"]])
        self.assertEqual(proj_list, ["proj"])

    def test_mhsa_matmul_process_with_wrong_length(self):
        """测试mhsa_matmul_process函数，matmul_list长度不正确"""
        matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
        matmul_list = [matmul1]
        qkv_list = []
        proj_list = []

        StructureProcess.mhsa_matmul_process(matmul_list, qkv_list, proj_list)

        self.assertEqual(qkv_list, [])
        self.assertEqual(proj_list, [])

    def test_mhsa_matmul_process_with_qkv_mat_inputs_not_equal_1(self):
        """测试mhsa_matmul_process函数，qkv_mat.inputs长度不为1"""
        qkv_mat = MockDagNode("qkv", "Linear", in_features=128, out_features=256)
        qkv_mat.inputs = [Mock(), Mock()]

        proj_mat = MockDagNode("proj", "Linear", in_features=256, out_features=256)

        matmul_list = [proj_mat, qkv_mat]
        qkv_list = []
        proj_list = []

        StructureProcess.mhsa_matmul_process(matmul_list, qkv_list, proj_list)

        self.assertEqual(qkv_list, [])
        self.assertEqual(proj_list, [])

    def test_mhsa_matmul_process_with_wrong_qkv_mat_list_length(self):
        """测试mhsa_matmul_process函数，qkv_mat_list长度不为3"""
        input_node = MockDagNode("input", "Add")
        qkv1 = MockDagNode("q", "Linear", in_features=128, out_features=256)
        qkv2 = MockDagNode("k", "Linear", in_features=128, out_features=256)

        input_node.output_nodes = [qkv1, qkv2, MockDagNode("other", "ReLU")]

        qkv_mat = MockDagNode("qkv", "Linear", in_features=128, out_features=256)
        qkv_mat.input_nodes = [input_node]
        qkv_mat.inputs = [Mock()]

        proj_mat = MockDagNode("proj", "Linear", in_features=256, out_features=256)

        matmul_list = [proj_mat, qkv_mat]
        qkv_list = []
        proj_list = []

        StructureProcess.mhsa_matmul_process(matmul_list, qkv_list, proj_list)

        self.assertEqual(qkv_list, [])
        self.assertEqual(proj_list, [])

    def test_mhsa_matmul_ln_process_with_proj_mat_out_features_times_3_equal_qkv_mat_out_features(self):
        """测试mhsa_matmul_ln_process函数，proj_mat.out_features * 3 == qkv_mat.out_features的情况"""
        proj_mat = MockDagNode("proj", "Linear", in_features=256, out_features=128)
        qkv_mat = MockDagNode("qkv", "Linear", in_features=128, out_features=384)

        ln1 = MockDagNode("ln", "LayerNorm")

        matmul_list = [proj_mat, qkv_mat]
        ln_list = [ln1]
        qkv_list = []
        proj_list = []
        mhsa_ln_list = []

        StructureProcess.mhsa_matmul_ln_process(matmul_list, ln_list, qkv_list, proj_list, mhsa_ln_list)

        self.assertEqual(qkv_list, [["qkv"]])
        self.assertEqual(proj_list, ["proj"])
        self.assertEqual(mhsa_ln_list, ["ln"])

    def test_mhsa_matmul_ln_process_with_proj_mat_out_features_equal_qkv_mat_out_features(self):
        """测试mhsa_matmul_ln_process函数，proj_mat.out_features == qkv_mat.out_features的情况"""
        input_node = MockDagNode("input", "Add")
        qkv1 = MockDagNode("q", "Linear", in_features=128, out_features=256)
        qkv2 = MockDagNode("k", "Linear", in_features=128, out_features=256)
        qkv3 = MockDagNode("v", "Linear", in_features=128, out_features=256)

        input_node.output_nodes = [qkv1, qkv2, qkv3, MockDagNode("other", "ReLU")]

        qkv_mat = MockDagNode("qkv", "Linear", in_features=128, out_features=256)
        qkv_mat.input_nodes = [input_node]
        qkv_mat.inputs = [Mock()]

        proj_mat = MockDagNode("proj", "Linear", in_features=256, out_features=256)
        ln1 = MockDagNode("ln", "LayerNorm")

        matmul_list = [proj_mat, qkv_mat]
        ln_list = [ln1]
        qkv_list = []
        proj_list = []
        mhsa_ln_list = []

        StructureProcess.mhsa_matmul_ln_process(matmul_list, ln_list, qkv_list, proj_list, mhsa_ln_list)

        self.assertEqual(qkv_list, [["q", "k", "v"]])
        self.assertEqual(proj_list, ["proj"])
        self.assertEqual(mhsa_ln_list, ["ln"])

    def test_mhsa_matmul_ln_process_with_wrong_length(self):
        """测试mhsa_matmul_ln_process函数，matmul_list长度不正确"""
        matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
        matmul_list = [matmul1]
        ln_list = [MockDagNode("ln", "LayerNorm")]
        qkv_list = []
        proj_list = []
        mhsa_ln_list = []

        StructureProcess.mhsa_matmul_ln_process(matmul_list, ln_list, qkv_list, proj_list, mhsa_ln_list)

        self.assertEqual(qkv_list, [])
        self.assertEqual(proj_list, [])
        self.assertEqual(mhsa_ln_list, [])

    def test_mhsa_matmul_ln_process_with_qkv_mat_inputs_not_equal_1(self):
        """测试mhsa_matmul_ln_process函数，qkv_mat.inputs长度不为1"""
        qkv_mat = MockDagNode("qkv", "Linear", in_features=128, out_features=256)
        qkv_mat.inputs = [Mock(), Mock()]

        proj_mat = MockDagNode("proj", "Linear", in_features=256, out_features=256)
        ln1 = MockDagNode("ln", "LayerNorm")

        matmul_list = [proj_mat, qkv_mat]
        ln_list = [ln1]
        qkv_list = []
        proj_list = []
        mhsa_ln_list = []

        StructureProcess.mhsa_matmul_ln_process(matmul_list, ln_list, qkv_list, proj_list, mhsa_ln_list)

        self.assertEqual(qkv_list, [])
        self.assertEqual(proj_list, [])
        self.assertEqual(mhsa_ln_list, [])

    def test_mhsa_matmul_ln_process_with_wrong_qkv_mat_list_length(self):
        """测试mhsa_matmul_ln_process函数，qkv_mat_list长度不为3"""
        input_node = MockDagNode("input", "Add")
        qkv1 = MockDagNode("q", "Linear", in_features=128, out_features=256)
        qkv2 = MockDagNode("k", "Linear", in_features=128, out_features=256)

        input_node.output_nodes = [qkv1, qkv2, MockDagNode("other", "ReLU")]

        qkv_mat = MockDagNode("qkv", "Linear", in_features=128, out_features=256)
        qkv_mat.input_nodes = [input_node]
        qkv_mat.inputs = [Mock()]

        proj_mat = MockDagNode("proj", "Linear", in_features=256, out_features=256)
        ln1 = MockDagNode("ln", "LayerNorm")

        matmul_list = [proj_mat, qkv_mat]
        ln_list = [ln1]
        qkv_list = []
        proj_list = []
        mhsa_ln_list = []

        StructureProcess.mhsa_matmul_ln_process(matmul_list, ln_list, qkv_list, proj_list, mhsa_ln_list)

        self.assertEqual(qkv_list, [])
        self.assertEqual(proj_list, [])
        self.assertEqual(mhsa_ln_list, [])


if __name__ == '__main__':
    unittest.main()